import torch
from tqdm import tqdm
import math
from torch.utils import data
import os.path
from os import path
import time
import torch.nn.functional as F
from AL_Dataset import ActiveLearning_Framework, idx_Dataset
import utils
import torch.optim as optim
import copy


def train_model(model, criterion, test_criterion, optimizer, device, ALset: ActiveLearning_Framework, 
                epoch_num, batch_size, weighted, patience=15):
    """
    for model training, optimization and testing
    :param: model: (nn.Module) current model; its forward returns a 3-tuple whose
            first element is fed to the loss functions
    :param: criterion: (nn.CrossEntropyLoss) loss function for training; when `weighted`
            is True its output is multiplied element-wise by per-sample weights and then
            averaged, so presumably it is built with reduction='none' (TODO confirm)
    :param: test_criterion: (nn.CrossEntropyLoss) loss function for testing; also used to
            report an unweighted "regular" training loss (must return a scalar)
    :param: optimizer: (optim.SGD)
    :param: device: (torch.device) GPU/CPU
    :param: ALset: (class: ActiveLearning_Framework) provides the train/test datasets and,
            when `weighted`, `ALset.weights` indexed by position in the train dataset
    :param: epoch_num: (int) number of training epochs
    :param: batch_size: (int) dataloader batchsize
    :param: weighted: (bool) whether use weighted loss for training
    :param: patience: (int) early stop patience (only used by the disabled early-stopping code)
    :return: (float) mean per-batch training loss of the last epoch
             (nan if no epoch or no batch was run)
    """
    trainset = ALset.get_train_dataset()
    idxset = idx_Dataset(len(trainset))
    # state for the (currently disabled) validation / early-stopping logic below
    best_validation_loss = math.inf
    early_stop_counter = 0
    model_save_PATH = './model/checkpoint.pt'

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    # defined up front so the final summary is well-defined even when epoch_num <= 0
    epoch = -1
    training_loss = math.nan
    regular_loss = math.nan

    for epoch in range(epoch_num):  # loop over the dataset multiple times
        running_loss = 0.0
        running_loss_regular = 0.0
        num_batches = 0
        train_idx_loader = torch.utils.data.DataLoader(idxset, batch_size=batch_size, shuffle=True)

        # train
        model.train()
        for indexes in train_idx_loader:
            # this trick is to get the batch and its indexes: the shuffled index batch
            # selects a Subset, which a one-shot loader collates into a single batch
            batchset = data.Subset(trainset, indexes)
            batch_loader = torch.utils.data.DataLoader(batchset, batch_size=len(batchset), shuffle=False)
            inputs, labels = next(iter(batch_loader))
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs, _, _ = model(inputs)
            if weighted:
                losses = criterion(outputs, labels)
                # `indexes` is a CPU tensor, so the selected weights may live on a
                # different device than the losses; align them before multiplying
                batch_weights = torch.as_tensor(ALset.weights[indexes], device=losses.device)
                loss = (losses * batch_weights).mean()
            else:
                loss = criterion(outputs, labels)
            # monitoring only: detach so no extra autograd graph is built for it
            loss_regular = test_criterion(outputs.detach(), labels)
            loss.backward()
            optimizer.step()

            # accumulate statistics
            running_loss += loss.item()
            running_loss_regular += loss_regular.item()
            num_batches += 1

        # average over the number of batches actually processed this epoch
        training_loss = running_loss / num_batches if num_batches else math.nan
        regular_loss = running_loss_regular / num_batches if num_batches else math.nan
        if (epoch+1)%10 ==0:
            print('epoch %d loss: %f\n' % (epoch + 1, training_loss))
            print('regular loss is {}\n'.format(regular_loss))

        #scheduler.step()  # turn on or turn off the scheduler

        # validate, early stop (disabled):
        # validate_loss = validate_model(model, test_criterion, device, ALset, batch_size)
        # if validate_loss < best_validation_loss:
        #     best_validation_loss = validate_loss
        #     torch.save(model.state_dict(), model_save_PATH)
        #     early_stop_counter = 0
        # else:
        #     if early_stop_counter >= patience:
        #         print('\n EARLY STOPPED!\n')
        #         break
        #     else:
        #         early_stop_counter += 1
        if (epoch+1)%20 ==0: 
            test_model(model, test_criterion, device, ALset, batch_size)

    #model.load_state_dict(torch.load(model_save_PATH))
    print('epoch %d loss: %.3f\n' % (epoch + 1, training_loss))
    print('regular loss is {}\n'.format(regular_loss))
    print('Finished Training\n')
    return training_loss


def validate_model(model, criterion, device, ALset: ActiveLearning_Framework, batch_size):
    """
    for model validation: average of the per-batch `criterion` values over the validation set
    :param: model: (nn.Module) current model; its forward returns a 3-tuple whose
            first element is fed to `criterion`
    :param: criterion: (nn.CrossEntropyLoss) loss function, expected to return a scalar per batch
    :param: device: (torch.device) GPU/CPU
    :param: ALset: (class: ActiveLearning_Framework)
    :param: batch_size: (int) dataloader batchsize
    :return: (float) mean of the per-batch losses
    :raises: ValueError if the validation dataset yields no batches
    """
    model.eval()
    validationset = ALset.get_validation_dataset()
    # no shuffling: order is irrelevant for evaluation, and a fixed order keeps the
    # per-batch average (which depends on the smaller last batch) deterministic
    validationloader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=False)
    total_loss = 0
    loss_count = 0
    print('Validating...')
    with torch.no_grad():
        for batch in tqdm(validationloader):
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs, _, _ = model(images)
            total_loss += criterion(outputs, labels)
            loss_count += 1
    if loss_count == 0:
        raise ValueError('validation dataset is empty; cannot compute a validation loss')
    loss = total_loss / loss_count
    print('Validation loss of the network on validation data:{}'.format(loss))
    return loss.item()


def test_model(model, criterion, device, ALset: ActiveLearning_Framework, batch_size):
    """
    for model testing: compute accuracy and the average per-batch loss on the test set
    :param: model: (nn.Module) current model; its forward returns a 3-tuple whose first
            element holds per-class scores (argmax over dim 1 is the prediction)
    :param: criterion: (nn.CrossEntropyLoss) loss function, expected to return a scalar per batch
    :param: device: (torch.device) GPU/CPU
    :param: ALset: (class: ActiveLearning_Framework)
    :param: batch_size: (int) dataloader batchsize
    :return: (tuple) accuracy in percent, mean of the per-batch losses
    :raises: ValueError if the test dataset yields no batches
    """
    model.eval()
    testset = ALset.get_test_dataset()
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    total = 0
    correct = 0
    total_loss = 0
    loss_count = 0
    print('Testing...')
    with torch.no_grad():
        for batch in tqdm(testloader):
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs, _, _ = model(images)
            # already under no_grad, so `.data` is unnecessary
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_loss += criterion(outputs, labels)
            loss_count += 1

    if loss_count == 0:
        raise ValueError('test dataset is empty; cannot compute accuracy or loss')
    loss = total_loss / loss_count
    accuracy = 100 * correct / total
    print('total test data ', total)
    print('Accuracy of the network on test images: %f %%\n' % accuracy)
    print('Loss of the network on test images:{}\n'.format(loss))

    return accuracy, loss.item()
